
Conversation

Ronald1995
Contributor

@Ronald1995 Ronald1995 commented Jul 22, 2025

What this PR does / why we need it?

vLLM's RowParallelLinear forward function executes the matmul and the all-reduce as two separate operations. This PR patches it to use torch_npu.npu_mm_all_reduce_base, which runs the matmul and the all-reduce as a single fused kernel, giving roughly a 20% performance improvement in eager mode.
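For reference, a minimal sketch of what the fused call looks like is below. It is not the actual patch in this PR: the helper name, the weight transpose, and the argument handling are simplified assumptions, and the exact npu_mm_all_reduce_base signature may differ across torch_npu versions.

```python
# Minimal illustrative sketch (not the actual patch): fuse the row-parallel
# matmul and the TP all-reduce into one kernel on Ascend NPUs.
import torch
import torch_npu  # provides npu_mm_all_reduce_base on Ascend builds


def fused_row_parallel_matmul(input_parallel: torch.Tensor,
                              weight: torch.Tensor,
                              hcomm_info: str) -> torch.Tensor:
    """Assumed weight layout: (output_size, input_size_per_partition)."""
    # Unfused baseline: out = input_parallel @ weight.t(), followed by an
    # explicit torch.distributed.all_reduce(out, group=tp_group).
    # Fused path: one kernel performs the matmul and the all-reduce over the
    # communicator identified by hcomm_info (the HCCL comm name).
    return torch_npu.npu_mm_all_reduce_base(input_parallel,
                                            weight.t(),
                                            hcomm_info,
                                            reduce_op="sum")
```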

Does this PR introduce any user-facing change?

This PR introduces a new environment variable, VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE, to control whether the feature is enabled.
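For example, the switch could presumably be enabled before launching the server as shown below (the accepted value "1" and the serve flags are assumptions; check the env definition in vllm-ascend, and `<model-path>` is a placeholder).

```shell
# Assumed usage: turn on the fused matmul + all-reduce path before launching
# the server. The value convention ("1") is an assumption; check the env
# definition in the vllm-ascend code.
export VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE=1
vllm serve <model-path> --tensor-parallel-size 4 --enforce-eager
```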

How was this patch tested?

The patch is covered by a new unit-test file, test_patch_linear.py, which guards the patched code.

@Ronald1995 Ronald1995 force-pushed the feature_allreduce branch 4 times, most recently from 84a48de to 1a2358b Compare July 24, 2025 06:29
@ApsarasX
Collaborator

Do you have any relevant performance data?

@Ronald1995 Ronald1995 force-pushed the feature_allreduce branch 2 times, most recently from 9f57e65 to 73b7ef7 Compare July 24, 2025 07:31
@Ronald1995 Ronald1995 changed the title WIP: implement the fusion of allreduce and matmul in prefill phase when tp is enabled [Feature]: implement the fusion of allreduce and matmul in prefill phase when tp is enabled Jul 24, 2025
@184603418

Hello, I have tried this PR with the following versions:
pip list | grep torch
torch 2.5.1
torch-npu 2.5.1.post1.dev20250619
I installed vLLM from the main branch and vllm-ascend from your Ronald1995:feature_allreduce branch, but when I start vllm serve I get the following error. Could you please help check whether my torch/torch_npu version is incorrect? Thanks very much!
ERROR 07-24 19:51:24 [multiproc_executor.py:140] Worker proc VllmWorker-1 died unexpectedly, shutting down executor.
Process EngineCore_0:
Traceback (most recent call last):
File "/usr/local/python3.11.10/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
self.run()
File "/usr/local/python3.11.10/lib/python3.11/multiprocessing/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "/home/l00847538/code/vllm/vllm/v1/engine/core.py", line 638, in run_engine_core
raise e
File "/home/l00847538/code/vllm/vllm/v1/engine/core.py", line 625, in run_engine_core
engine_core = EngineCoreProc(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/l00847538/code/vllm/vllm/v1/engine/core.py", line 443, in init
super().init(vllm_config, executor_class, log_stats,
File "/home/l00847538/code/vllm/vllm/v1/engine/core.py", line 87, in init
self._initialize_kv_caches(vllm_config)
File "/home/l00847538/code/vllm/vllm/v1/engine/core.py", line 159, in _initialize_kv_caches
self.model_executor.determine_available_memory())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/l00847538/code/vllm/vllm/v1/executor/abstract.py", line 76, in determine_available_memory
output = self.collective_rpc("determine_available_memory")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/l00847538/code/vllm/vllm/v1/executor/multiproc_executor.py", line 237, in collective_rpc
result = get_response(w, dequeue_timeout)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/l00847538/code/vllm/vllm/v1/executor/multiproc_executor.py", line 224, in get_response
raise RuntimeError(
RuntimeError: Worker failed with error 'Dict key of type <class 'torch._dynamo.variables.lazy.LazyVariableTracker'>. Key: ProcessGroupVariable()

from user code:
File "/home/l00847538/code/vllm/vllm/model_executor/models/qwen2.py", line 354, in forward
hidden_states, residual = layer(
File "/home/l00847538/code/vllm/vllm/model_executor/models/qwen2.py", line 253, in forward
hidden_states = self.self_attn(
File "/home/l00847538/code/vllm/vllm/model_executor/models/qwen2.py", line 184, in forward
output, _ = self.o_proj(attn_output)
File "/home/l00847538/code/vllm-ascend/vllm_ascend/patch/worker/patch_common/patch_linear.py", line 84, in forward
output = self.calc_output(input_parallel)
File "/home/l00847538/code/vllm-ascend/vllm_ascend/patch/worker/patch_common/patch_linear.py", line 122, in calc_output
hcomm_info = self.get_hcomm_info(tp_group)
File "/home/l00847538/code/vllm-ascend/vllm_ascend/patch/worker/patch_common/patch_linear.py", line 57, in get_hcomm_info
global_rank = torch.distributed.get_global_rank(group, rank)
File "/usr/local/python3.11.10/lib/python3.11/site-packages/torch/distributed/distributed_c10d.py", line 949, in get_global_rank
if group not in _world.pg_group_ranks:
File "/usr/local/python3.11.10/lib/python3.11/site-packages/torch/distributed/distributed_c10d.py", line 621, in pg_group_ranks
return _pg_group_ranks

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
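The failing frame is the communicator-name lookup inside get_hcomm_info; torch.distributed.get_global_rank is what Dynamo chokes on, because it indexes a dict keyed by the ProcessGroup object. A rough, hedged sketch of that kind of lookup is below — the _get_backend accessor and get_hccl_comm_name call are assumptions about torch_npu's HCCL backend, not the exact code in patch_linear.py.

```python
# Hedged sketch of resolving the HCCL communicator name for a process group;
# the real patch_linear.py implementation may differ.
import torch
import torch_npu  # registers the "npu" device type
import torch.distributed as dist


def get_hcomm_info(group: dist.ProcessGroup) -> str:
    rank = dist.get_rank(group)
    # This is the lookup that fails under Dynamo tracing: it indexes a dict
    # keyed by the ProcessGroup object.
    global_rank = dist.get_global_rank(group, rank)
    # Assumption: torch_npu's HCCL backend exposes get_hccl_comm_name().
    backend = group._get_backend(torch.device("npu"))
    return backend.get_hccl_comm_name(global_rank)
```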

@Ronald1995
Contributor Author

Hello, I have tried this PR with torch 2.5.1 and torch-npu 2.5.1.post1.dev20250619, but when I start vllm serve I get an error; please help check whether my torch/torch_npu version is incorrect … (full error traceback quoted above)

The torch in my environment is torch==2.7.1 and torch_npu==2.7.1rc1.

@Ronald1995
Contributor Author

Do you have any relevant performance data?

I validated this feature on A2.
tp=4, 2k->20k, gbs=24, eager mode.
Using the vLLM benchmark, the original throughput was 270 tps; with the feature enabled, throughput improved to 340 tps.
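For context, numbers of this kind are typically collected with vLLM's serving benchmark; a hedged sketch of such a run is below. The script path, flag names, prompt count, and the 2048/20480 lengths are assumptions that vary across vLLM versions, and `<model-path>` is a placeholder — this is not the exact command used here.

```shell
# Hedged sketch of a vLLM serving benchmark run; flags differ across versions.
python benchmarks/benchmark_serving.py \
    --backend vllm \
    --model <model-path> \
    --dataset-name random \
    --random-input-len 2048 \
    --random-output-len 20480 \
    --num-prompts 200
```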

@184603418

Do you have any relevant performance data?

I validated this feature on A2. tp=4, 2k->20k, gbs=24, eager mode. Using the vLLM benchmark, the original throughput was 270 tps; with the feature enabled, throughput improved to 340 tps.

Thanks very much. I want to use this feature on A3, and my goal is to decrease decode latency. I saw that this feature works in the prefill phase; does it also work in the decode phase?

@Ronald1995 Ronald1995 force-pushed the feature_allreduce branch 2 times, most recently from a54cfb1 to 4547ae7 Compare July 25, 2025 06:57
@Ronald1995 Ronald1995 force-pushed the feature_allreduce branch 4 times, most recently from 46de3a3 to 0125904 Compare July 25, 2025 09:14

codecov bot commented Jul 25, 2025

Codecov Report

❌ Patch coverage is 94.40000% with 7 lines in your changes missing coverage. Please review.
✅ Project coverage is 72.00%. Comparing base (df0ec55) to head (7302640).
⚠️ Report is 6 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| ...m_ascend/patch/worker/patch_common/patch_linear.py | 86.00% | 7 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1926      +/-   ##
==========================================
+ Coverage   71.73%   72.00%   +0.26%     
==========================================
  Files          96       98       +2     
  Lines       10719    10843     +124     
==========================================
+ Hits         7689     7807     +118     
- Misses       3030     3036       +6     
| Flag | Coverage Δ |
|---|---|
| unittests | 72.00% <94.40%> (+0.26%) ⬆️ |



This pull request has conflicts, please resolve those before we can evaluate the pull request.


This pull request has conflicts, please resolve those before we can evaluate the pull request.

wangxiyuan
wangxiyuan previously approved these changes Jul 28, 2025
@wangxiyuan
Collaborator

Hello, I have tried this PR with torch 2.5.1 and torch-npu 2.5.1.post1.dev20250619, but when I start vllm serve I get an error; please help check whether my torch/torch_npu version is incorrect … (full error traceback quoted above)

The torch in my environment is torch==2.7.1 and torch_npu==2.7.1rc1.

So this feature relies on PTA 2.7.1, right? If so, I think we should merge this after PTA is upgraded on the main branch.

@wangxiyuan wangxiyuan dismissed their stale review July 28, 2025 06:34

rely on pta 2.7.1

@Ronald1995
Contributor Author

2.7.1

This feature doesn't rely on torch 2.7.1; npu_mm_all_reduce_base has been supported since torch_npu 2.1.0. I mentioned torch 2.7.1 only to tell the questioner which torch version is in my environment. Another colleague successfully ran this feature with torch 2.5.1.
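A quick way to confirm whether a given environment exposes the fused op (just a sanity check, independent of this PR):

```python
# Sanity check: does the installed torch_npu build expose the fused op?
import torch
import torch_npu

print(torch.__version__, torch_npu.__version__)
print(hasattr(torch_npu, "npu_mm_all_reduce_base"))
```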

@ganyi1996ppo ganyi1996ppo merged commit 32a9c5f into vllm-project:main Jul 28, 2025
25 checks passed
@ApsarasX
Collaborator

ApsarasX commented Jul 28, 2025

This PR improves performance with small batches but degrades it with large batches, specifically during decoding.

  • model: Qwen3-32B
  • launch command
export VLLM_USE_V1=1
export TASK_QUEUE_ENABLE=2
export VLLM_VERSION=0.9.1
export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7

export PROMETHEUS_MULTIPROC_DIR=/tmp/

nohup python -m vllm.entrypoints.openai.api_server \
    --model=/home/admin/model \
    --trust-remote-code \
    --served-model-name auto \
    --load-format=prefetch_auto \
    --distributed-executor-backend=mp \
    --port 8006 \
    -tp=8 \
    --enforce-eager \
    --max-num-seqs 192 \
    --max-model-len 32768 \
    --max-num-batched-tokens 16384 \
    --block-size 128 \
    --gpu-memory-utilization 0.9 \
    --enable-prompt-tokens-details &> run.log &
disown

Before this PR

| batch | p90 TPOT (ms) |
|---|---|
| 1 | 22.63 |
| 40 | 35.51 |
| 80 | 71.67 |

After this PR

| batch | p90 TPOT (ms) |
|---|---|
| 1 | 18.77 |
| 40 | 48.53 |
| 80 | 74.75 |

@wangxiyuan @ganyi1996ppo
